760c5c
@@ -22,6 +22,7 @@
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncRoundWithNumDigitsDecimalToDecimal;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.RoundWithNumDigitsDoubleToDouble;
@@ -83,38 +84,42 @@
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
           "ROUND requires one or two argument, got " + arguments.length);
     }
 
-    inputOI = (PrimitiveObjectInspector) arguments[0];
-    if (inputOI.getCategory() != Category.PRIMITIVE) {
-      throw new UDFArgumentException(
-          "ROUND input only takes primitive types, got " + inputOI.getTypeName());
+    if (arguments[0].getCategory() != Category.PRIMITIVE) {
+      throw new UDFArgumentTypeException(0,
+          "ROUND input only takes primitive types, got " + arguments[0].getTypeName());
     }
+    inputOI = (PrimitiveObjectInspector) arguments[0];
 
     if (arguments.length == 2) {
+      if (arguments[1].getCategory() != Category.PRIMITIVE) {
+        throw new UDFArgumentTypeException(1,
+            "ROUND second argument only takes primitive types, got " + arguments[1].getTypeName());
+      }
       PrimitiveObjectInspector scaleOI = (PrimitiveObjectInspector) arguments[1];
       switch (scaleOI.getPrimitiveCategory()) {
       case VOID:
         break;
       case BYTE:
         if (!(scaleOI instanceof WritableConstantByteObjectInspector)) {
-          throw new UDFArgumentException("ROUND second argument only takes constant");
+          throw new UDFArgumentTypeException(1, "ROUND second argument only takes constant");
         }
         scale = ((WritableConstantByteObjectInspector)scaleOI).getWritableConstantValue().get();
         break;
       case SHORT:
         if (!(scaleOI instanceof WritableConstantShortObjectInspector)) {
-          throw new UDFArgumentException("ROUND second argument only takes constant");
+          throw new UDFArgumentTypeException(1, "ROUND second argument only takes constant");
         }
         scale = ((WritableConstantShortObjectInspector)scaleOI).getWritableConstantValue().get();
         break;
       case INT:
         if (!(scaleOI instanceof WritableConstantIntObjectInspector)) {
-          throw new UDFArgumentException("ROUND second argument only takes constant");
+          throw new UDFArgumentTypeException(1, "ROUND second argument only takes constant");
         }
         scale = ((WritableConstantIntObjectInspector)scaleOI).getWritableConstantValue().get();
         break;
       case LONG:
         if (!(scaleOI instanceof WritableConstantLongObjectInspector)) {
-          throw new UDFArgumentException("ROUND second argument only takes constant");
+          throw new UDFArgumentTypeException(1, "ROUND second argument only takes constant");
         }
         long l = ((WritableConstantLongObjectInspector)scaleOI).getWritableConstantValue().get();
         if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) {
@@ -123,7 +128,7 @@
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
         scale = (int)l;
         break;
       default:
-        throw new UDFArgumentException("ROUND second argument only takes integer constant");
+        throw new UDFArgumentTypeException(1, "ROUND second argument only takes integer constant");
       }
     }
 
@@ -151,8 +156,9 @@
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
       converterFromString = ObjectInspectorConverters.getConverter(inputOI, outputOI);
       break;
     default:
-      throw new UDFArgumentException("Only numeric data types are allowed for ROUND function. Got " +
-          inputType.name());
+      throw new UDFArgumentTypeException(0,
+          "Only numeric or string group data types are allowed for ROUND function. Got "
+              + inputType.name());
     }
 
     return outputOI;
@@ -240,8 +246,9 @@
public Object evaluate(DeferredObject[] arguments) throws HiveException {
        }
        return round(doubleValue, scale);
      default:
-       throw new UDFArgumentException("Only numeric data types are allowed for ROUND function. Got " +
-           inputType.name());
+      throw new UDFArgumentTypeException(0,
+          "Only numeric or string group data types are allowed for ROUND function. Got "
+              + inputType.name());
     }
   }
 
